Using GraphSAGE to Generate Embeddings for Unseen Data

The GraphSAGE (SAmple and aggreGatE) algorithm [13] emerged in 2017 as a method for not only learning useful vertex embeddings, but also predicting embeddings for unseen vertices. Powerful high-level feature vectors can therefore be produced for vertices that were not seen at training time, letting us work effectively with dynamic graphs, or with very large graphs (>100,000 vertices).

A GraphSAGE net is built from k convolutional layers, which DGL calls SAGEConv layers. Like other GNNs, they use a message-passing algorithm to combine neighbourhood features for each vertex; the gathered features are aggregated using a reduce function such as a mean or a max pool.
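Concretely, the mean-aggregator variant implemented below updates the embedding of vertex $v$ at layer $k$ as

$$ h_v^{(k)} = \mathrm{ReLU}\left( W_{\mathrm{self}}\, h_v^{(k-1)} + W_{\mathrm{neigh}} \cdot \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} \right) $$

which is equivalent, up to a re-partitioning of the weight matrix, to the paper's formulation of concatenating $h_v^{(k-1)}$ with the aggregated neighbourhood vector before a single linear map.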

Setup

Here we load the required libraries and set up the environment. Feel free to skip this section.

In [1]:
import numpy as np
import networkx as nx

from IPython.display import HTML
import matplotlib.animation as animation
import matplotlib.pyplot as plt

import dgl
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
Using backend: pytorch

Datasets

In this example we use the Cora dataset (see Figure 19) as provided by the deep learning library DGL.

The Cora dataset is often described as ‘the MNIST of graph-based learning’. It consists of 2708 scientific publications (vertices), each classified into one of seven AI subfields (the classes). Each vertex carries a 1433-element binary feature vector indicating whether each of 1433 designated words appears in the publication.

In [2]:
# To demonstrate, let's use the Cora dataset.
# DGL provides an API to access this and other datasets.
from dgl.data import citation_graph
data = citation_graph.CoraDataset()
In [3]:
features = data.features
# instantiating the dataset above downloaded and cached the data for later use
# let's investigate
n_features = int(features.shape[1])
n_nodes = int(features.shape[0])
n_edges = data.graph.number_of_edges()

print(f'There are {n_nodes} nodes and {n_edges} edges')
print(f'Each node has {n_features} features')

# let's look at the labels, the classification target
labels = data.labels
n_classes = labels.max() + 1
print("There are {} classes".format(n_classes))
plt.hist(labels.flatten(), bins=n_classes)
There are 2708 nodes and 10556 edges
Each node has 1433 features
There are 7 classes
Out[3]:
(array([298., 418., 818., 426., 217., 180., 351.]),
 array([0.        , 0.85714286, 1.71428571, 2.57142857, 3.42857143,
        4.28571429, 5.14285714, 6.        ]),
 <a list of 7 Patch objects>)
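As a quick aside (our addition, not one of the original cells), you can peek at a single publication's bag-of-words vector to get a feel for how sparse these features are:

v = features[0]
print(v.shape)
print(int(v.sum()), 'of the 1433 designated words appear in publication 0')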
In [4]:
# DGL datasets come with ready-made train/test/val splits, in the form of index masks
train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask
print(int(train_mask.sum()), 'train samples')
print(int(val_mask.sum()), 'validation samples')
print(int(test_mask.sum()), 'test samples')
140 train samples
300 validation samples
1000 test samples
In [5]:
# Let's convert the data from numpy arrays to the required pytorch tensors. 
features = torch.FloatTensor(features)
labels = torch.LongTensor(labels)
train_mask = torch.BoolTensor(train_mask)
val_mask = torch.BoolTensor(val_mask)
test_mask = torch.BoolTensor(test_mask)
In [6]:
# If we are using the gpu, we can send the arrays to gpu memory.
gpu = 0
if gpu >= 0:
    torch.cuda.set_device(gpu)
    features = features.cuda()
    labels = labels.cuda()
    train_mask = train_mask.cuda()
    val_mask = val_mask.cuda()
    test_mask = test_mask.cuda()   
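Note that this assumes a CUDA device exists at index 0. If you might be running on a CPU-only machine, a small defensive tweak (our suggestion, not part of the original flow) would be:

# fall back to the CPU when no CUDA device is available
gpu = 0 if torch.cuda.is_available() else -1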

A subgraph of the Cora dataset. The full Cora graph has N = 2708 and M = 5429. Note the many vertices with few incident edges (low degree) as compared to the few vertices with many incident edges (high degree).
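To check the caption's claim about the skewed degree distribution yourself, here is a short sketch (our addition) using the networkx graph that DGL provides:

# summarise how many incident edges each vertex has
degrees = [d for _, d in data.graph.degree()]
print('min degree:', min(degrees), '| max degree:', max(degrees))
plt.hist(degrees, bins=50)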

In [7]:
# DGL datasets come with a pre-initialised networkx graph
g = data.graph
# first remove any existing self-loops, because GraphSAGE employs
# its own way of dealing with self-loops in the forward pass
g.remove_edges_from(nx.selfloop_edges(g))
# and let's recalculate the number of edges for later
n_edges = g.number_of_edges()
# for simplicity let's convert the graph to an undirected one
g = g.to_undirected()

# with a networkx graph we can do some plotting;
# let's just plot a fraction of the nodes
g_copy = g.copy()
g_copy.remove_nodes_from(range(500, n_nodes))
nx.draw(g_copy, node_size=10, alpha=0.6, arrows=False, edge_color='purple')
In [8]:
# We can build a trainable GNN out of this networkx graph with DGL.
# The DGLGraph class can take a networkx graph as input
g = dgl.DGLGraph(g)

Architecture and initial experiments

We'll start by setting up our own layers, models, and training routines.

In [9]:
# Like all layers and neural nets in pytorch, we inherit from the nn.Module class
class MeanAggSageLayer(nn.Module):
    def __init__(self, n_features_in, n_features_out):
        super(MeanAggSageLayer, self).__init__()
        # number of features coming in to this layer. If this is the first layer,
        # this will be the number of features per node
        self._in = n_features_in
        # the number of output features from this layer,
        # In the final layer of the GraphSAGE net this will equal n_classes 
        self._out = n_features_out
        # create a linear transformation between the input channels and the output.
        # These nn.Linear objects are shortcuts to hold the weights and biases
        # that are learnt through backpropogation, and applied
        # to incoming features. We will have one for self nodes 
        self.fc_self = nn.Linear(self._in, self._out)
        # and one for neighbour nodes 
        self.fc_neigh = nn.Linear(self._in, self._out)
        # we will initialise the weights with xavier_uniform random
        # sampling (another name for the Glorot uniform initialisation
        # used in the original GraphSAGE paper)
        gain = nn.init.calculate_gain('relu')   # sqrt(2)
        # set the gain appropriately for our activation function 
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
    
    def forward(self, graph, features):
        """
        The following code is DGL's way of using the graph class
        to facilitate message passing. Roughly equivalent code in pure
        pytorch, operating instead on a dense adjacency matrix adj and
        the feature matrix x, would be:

            def forward(self, x, adj):
                h_neigh = (adj @ x) / adj.sum(1, keepdim=True)
                return x @ self_weights + h_neigh @ neigh_weights + bias
        
        """
        # set the incoming features matrix as the input to this layer 'h'
        graph.srcdata['h'] = features
        # create two built-in message functions: the first copies each source
        # node's features 'h' onto its outgoing edges as messages 'm', and the
        # second averages the incoming messages at each destination node
        features_from_src_nodes = dgl.function.copy_src('h', 'm')
        aggregation_at_dst_nodes = dgl.function.mean('m', 'neigh')
        # graph.update_all is a helper function to send the first function
        # along the edges and recieve the second function at the
        # destination nodes
        graph.update_all(features_from_src_nodes, aggregation_at_dst_nodes)
        # now we can get our aggregated neighbourhood features
        h_neigh = graph.dstdata['neigh']
        # and combine them with the src features (self loops)
        # fc_self(features) is equivalent to features @ weights + biases
        output = self.fc_self(features) + self.fc_neigh(h_neigh)
        # lastly we apply a nonlinearity so that stacked layers can
        # learn non-linear functions
        output = F.relu(output)
        return output
    

The only method we need to define ourselves is the forward method (the forward pass); backpropagation is handled by the library.
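To make the docstring's pseudocode concrete, here is a hedged pure-pytorch sketch of the same layer operating on a dense adjacency matrix (the class name DenseMeanSage is ours; this is fine for toy graphs but wasteful for large sparse ones):

class DenseMeanSage(nn.Module):
    def __init__(self, n_features_in, n_features_out):
        super(DenseMeanSage, self).__init__()
        self.fc_self = nn.Linear(n_features_in, n_features_out)
        self.fc_neigh = nn.Linear(n_features_in, n_features_out)

    def forward(self, adj, x):
        # mean over neighbours = row-normalised adjacency times features
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        h_neigh = (adj @ x) / deg
        return F.relu(self.fc_self(x) + self.fc_neigh(h_neigh))

# tiny smoke test on a two-vertex graph
adj = torch.tensor([[0., 1.], [1., 0.]])
x = torch.randn(2, 4)
print(DenseMeanSage(4, 3)(adj, x).shape)  # torch.Size([2, 3])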

Now let's build a GraphSAGE GNN out of these layers, one that takes the DGLGraph we made previously.

In [10]:
class SimpleGraphSAGE(nn.Module):
    def __init__(
            self, 
            g, 
            n_features, 
            n_hidden, 
            n_classes, 
            n_layers
    ):
        super(SimpleGraphSAGE, self).__init__()
        # A ModuleList will hold all of our layers
        self.conv_layers = nn.ModuleList()
        self.g = g

        # input layer, the input size of which will be 
        # the number of features
        self.conv_layers.append(MeanAggSageLayer(n_features, n_hidden))
        # create the hidden layers: (n_layers - 1) allowing for the output layer
        for i in range(n_layers - 1):
            self.conv_layers.append(MeanAggSageLayer(n_hidden, n_hidden))
        # output layer, the output size of which will be the number of classes
        self.conv_layers.append(MeanAggSageLayer(n_hidden, n_classes))

    def forward(self, features):
        # h(0) will be equal to the feature matrix
        h = features
        for conv in self.conv_layers:
            # pass h through one layer and back into the next
            h = conv(self.g, h)
        # now we have h(k)
        return h

Before we create one of these models we need to decide on some hyperparameters:

In [11]:
n_hidden = 16
n_layers = 2
learning_rate = 0.01
weight_decay = 0.0005
n_epochs = 120

Now we can create a GraphSAGE model using our graph (g)

In [12]:
model = SimpleGraphSAGE(g, n_features, n_hidden, n_classes, n_layers)
# we can send this to gpu memory as well
if gpu >= 0:
    model.cuda()
In [13]:
# use cross entropy loss function
loss_fcn = torch.nn.CrossEntropyLoss()

# use Adam Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=weight_decay)
In [14]:
# we also need a scoring function; let's create a simple accuracy calculator:
def get_accuracy(pred, true):
    _, indices = torch.max(pred, dim=1)
    correct = torch.sum(indices == true)
    return correct.item() * 1.0 / len(true)
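A quick sanity check with hypothetical values (our addition): two predictions, both correct, should score 1.0.

pred = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
true = torch.tensor([0, 1])
print(get_accuracy(pred, true))  # 1.0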

And we can decide on a simple training routine too.

In [15]:
# now we can build our training pipeline
def train(model, optimizer, n_epochs):
    # we will keep track of how long each epoch takes so we can calculate
    # things like Traversed Edges Per Second (TEPS)
    dur = []
    all_train_logits = []

    for epoch in range(n_epochs):

        # This doesn't train the model; instead it tells all the child modules
        # that the model is in training mode rather than evaluation mode
        # (for example, when evaluating you don't want to apply dropout to the input tensor)
        model.train()
        t0 = time.time()

        # the forward pass - sending the features to the model.forward method
        output = model(features)
        # calculate our current loss by comparing only the training nodes'
        # prediction and truth
        output_train = output[train_mask]
        loss = loss_fcn(output_train, labels[train_mask])

        # the backwards pass! update the weights in our SAGE layers - but first:
        # reset the gradients back to 0 before doing backpropagation
        # (pytorch by default accumulates the gradients after each backward pass)
        optimizer.zero_grad()
        # backpropagation
        loss.backward()
        # step the adam optimizer forward
        optimizer.step()

        dur.append(time.time() - t0)

        # set the model into evaluation mode
        model.eval()
        # temporarily turn off gradient calculation as we only want to run
        # inference here (strictly, one would re-run the forward pass under
        # eval(); we reuse this epoch's logits to keep the loop simple)
        with torch.no_grad():
            output_val = output[val_mask]
            labels_val = labels[val_mask]
            acc = get_accuracy(output_val, labels_val)

        # record the output logits for plotting later
        all_train_logits.append(output_train)

        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "TEPS {:.2f}".format(epoch, np.mean(dur), loss.item(), acc, 
                                   n_edges / np.mean(dur)))

    print('training complete')
    return model, output, all_train_logits
In [16]:
model, last_output, _ = train(model, optimizer, n_epochs)
Epoch 00000 | Time(s) 0.3513 | Loss 1.9933 | Accuracy 0.1033 | TEPS 30045.94
Epoch 00001 | Time(s) 0.1781 | Loss 1.9336 | Accuracy 0.2000 | TEPS 59266.31
Epoch 00002 | Time(s) 0.1206 | Loss 1.9098 | Accuracy 0.1567 | TEPS 87550.96
Epoch 00003 | Time(s) 0.0917 | Loss 1.8938 | Accuracy 0.1567 | TEPS 115125.11
Epoch 00004 | Time(s) 0.0742 | Loss 1.8687 | Accuracy 0.1567 | TEPS 142291.11
Epoch 00005 | Time(s) 0.0626 | Loss 1.8287 | Accuracy 0.1567 | TEPS 168603.25
Epoch 00006 | Time(s) 0.0545 | Loss 1.7855 | Accuracy 0.1567 | TEPS 193795.74
Epoch 00007 | Time(s) 0.0484 | Loss 1.7421 | Accuracy 0.1933 | TEPS 218315.54
Epoch 00008 | Time(s) 0.0435 | Loss 1.6985 | Accuracy 0.2467 | TEPS 242605.95
Epoch 00009 | Time(s) 0.0396 | Loss 1.6576 | Accuracy 0.2500 | TEPS 266485.18
Epoch 00010 | Time(s) 0.0365 | Loss 1.6097 | Accuracy 0.2500 | TEPS 289519.25
Epoch 00011 | Time(s) 0.0338 | Loss 1.5648 | Accuracy 0.2533 | TEPS 312007.22
Epoch 00012 | Time(s) 0.0315 | Loss 1.5282 | Accuracy 0.2800 | TEPS 334630.18
Epoch 00013 | Time(s) 0.0296 | Loss 1.4899 | Accuracy 0.3100 | TEPS 356399.88
Epoch 00014 | Time(s) 0.0280 | Loss 1.4539 | Accuracy 0.3300 | TEPS 376901.94
Epoch 00015 | Time(s) 0.0266 | Loss 1.4232 | Accuracy 0.3400 | TEPS 396565.48
Epoch 00016 | Time(s) 0.0254 | Loss 1.3881 | Accuracy 0.3567 | TEPS 416099.32
Epoch 00017 | Time(s) 0.0242 | Loss 1.3532 | Accuracy 0.4100 | TEPS 436313.89
Epoch 00018 | Time(s) 0.0232 | Loss 1.3218 | Accuracy 0.4200 | TEPS 455192.58
Epoch 00019 | Time(s) 0.0223 | Loss 1.2883 | Accuracy 0.4267 | TEPS 473721.71
Epoch 00020 | Time(s) 0.0214 | Loss 1.2546 | Accuracy 0.4300 | TEPS 492875.72
Epoch 00021 | Time(s) 0.0207 | Loss 1.2205 | Accuracy 0.4333 | TEPS 510958.06
Epoch 00022 | Time(s) 0.0200 | Loss 1.1838 | Accuracy 0.4367 | TEPS 527850.57
Epoch 00023 | Time(s) 0.0194 | Loss 1.1479 | Accuracy 0.4400 | TEPS 544357.84
Epoch 00024 | Time(s) 0.0188 | Loss 1.1103 | Accuracy 0.4400 | TEPS 561799.96
Epoch 00025 | Time(s) 0.0182 | Loss 1.0718 | Accuracy 0.4500 | TEPS 579010.06
Epoch 00026 | Time(s) 0.0178 | Loss 1.0334 | Accuracy 0.4533 | TEPS 594272.35
Epoch 00027 | Time(s) 0.0173 | Loss 0.9967 | Accuracy 0.4533 | TEPS 609420.49
Epoch 00028 | Time(s) 0.0169 | Loss 0.9610 | Accuracy 0.4533 | TEPS 625420.85
Epoch 00029 | Time(s) 0.0165 | Loss 0.9272 | Accuracy 0.4633 | TEPS 641027.94
Epoch 00030 | Time(s) 0.0161 | Loss 0.8965 | Accuracy 0.4700 | TEPS 655138.34
Epoch 00031 | Time(s) 0.0158 | Loss 0.8688 | Accuracy 0.4800 | TEPS 668699.48
Epoch 00032 | Time(s) 0.0155 | Loss 0.8445 | Accuracy 0.4800 | TEPS 682865.48
Epoch 00033 | Time(s) 0.0151 | Loss 0.8231 | Accuracy 0.4800 | TEPS 697541.19
Epoch 00034 | Time(s) 0.0148 | Loss 0.8045 | Accuracy 0.4833 | TEPS 711099.28
Epoch 00035 | Time(s) 0.0146 | Loss 0.7883 | Accuracy 0.4833 | TEPS 723552.35
Epoch 00036 | Time(s) 0.0143 | Loss 0.7744 | Accuracy 0.4833 | TEPS 735716.70
Epoch 00037 | Time(s) 0.0141 | Loss 0.7625 | Accuracy 0.4867 | TEPS 747927.98
Epoch 00038 | Time(s) 0.0139 | Loss 0.7521 | Accuracy 0.4900 | TEPS 760637.21
Epoch 00039 | Time(s) 0.0137 | Loss 0.7435 | Accuracy 0.4800 | TEPS 772127.98
Epoch 00040 | Time(s) 0.0135 | Loss 0.7362 | Accuracy 0.4800 | TEPS 784697.67
Epoch 00041 | Time(s) 0.0132 | Loss 0.7301 | Accuracy 0.4800 | TEPS 796934.35
Epoch 00042 | Time(s) 0.0131 | Loss 0.7246 | Accuracy 0.4800 | TEPS 806973.55
Epoch 00043 | Time(s) 0.0129 | Loss 0.7199 | Accuracy 0.4900 | TEPS 818121.36
Epoch 00044 | Time(s) 0.0127 | Loss 0.7160 | Accuracy 0.4833 | TEPS 829776.61
Epoch 00045 | Time(s) 0.0126 | Loss 0.7124 | Accuracy 0.4900 | TEPS 841024.78
Epoch 00046 | Time(s) 0.0124 | Loss 0.7093 | Accuracy 0.4867 | TEPS 850957.65
Epoch 00047 | Time(s) 0.0123 | Loss 0.7069 | Accuracy 0.4800 | TEPS 860931.90
Epoch 00048 | Time(s) 0.0121 | Loss 0.7048 | Accuracy 0.4833 | TEPS 871144.99
Epoch 00049 | Time(s) 0.0120 | Loss 0.7030 | Accuracy 0.4833 | TEPS 882138.04
Epoch 00050 | Time(s) 0.0118 | Loss 0.7015 | Accuracy 0.4833 | TEPS 891888.08
Epoch 00051 | Time(s) 0.0117 | Loss 0.7001 | Accuracy 0.4833 | TEPS 901174.62
Epoch 00052 | Time(s) 0.0116 | Loss 0.6993 | Accuracy 0.4833 | TEPS 909886.42
Epoch 00053 | Time(s) 0.0115 | Loss 0.6987 | Accuracy 0.4833 | TEPS 919983.82
Epoch 00054 | Time(s) 0.0113 | Loss 0.6981 | Accuracy 0.4867 | TEPS 930124.08
Epoch 00055 | Time(s) 0.0112 | Loss 0.6981 | Accuracy 0.4833 | TEPS 938704.55
Epoch 00056 | Time(s) 0.0111 | Loss 0.6978 | Accuracy 0.4867 | TEPS 947608.73
Epoch 00057 | Time(s) 0.0110 | Loss 0.6973 | Accuracy 0.4833 | TEPS 957398.81
Epoch 00058 | Time(s) 0.0109 | Loss 0.6971 | Accuracy 0.4800 | TEPS 966726.83
Epoch 00059 | Time(s) 0.0108 | Loss 0.6969 | Accuracy 0.4867 | TEPS 974567.15
Epoch 00060 | Time(s) 0.0107 | Loss 0.6972 | Accuracy 0.4867 | TEPS 982472.10
Epoch 00061 | Time(s) 0.0106 | Loss 0.6974 | Accuracy 0.4833 | TEPS 991218.42
Epoch 00062 | Time(s) 0.0106 | Loss 0.6977 | Accuracy 0.4800 | TEPS 1000208.91
Epoch 00063 | Time(s) 0.0105 | Loss 0.6980 | Accuracy 0.4833 | TEPS 1008088.84
Epoch 00064 | Time(s) 0.0104 | Loss 0.6977 | Accuracy 0.4867 | TEPS 1016003.16
Epoch 00065 | Time(s) 0.0103 | Loss 0.6979 | Accuracy 0.4867 | TEPS 1024735.36
Epoch 00066 | Time(s) 0.0102 | Loss 0.6984 | Accuracy 0.4867 | TEPS 1032537.23
Epoch 00067 | Time(s) 0.0102 | Loss 0.6981 | Accuracy 0.4867 | TEPS 1039217.64
Epoch 00068 | Time(s) 0.0101 | Loss 0.6985 | Accuracy 0.4867 | TEPS 1045401.18
Epoch 00069 | Time(s) 0.0100 | Loss 0.6987 | Accuracy 0.4867 | TEPS 1054151.91
Epoch 00070 | Time(s) 0.0099 | Loss 0.6989 | Accuracy 0.4833 | TEPS 1062318.55
Epoch 00071 | Time(s) 0.0099 | Loss 0.6995 | Accuracy 0.4867 | TEPS 1068795.06
Epoch 00072 | Time(s) 0.0098 | Loss 0.7000 | Accuracy 0.4933 | TEPS 1076709.42
Epoch 00073 | Time(s) 0.0097 | Loss 0.6990 | Accuracy 0.4833 | TEPS 1084722.11
Epoch 00074 | Time(s) 0.0097 | Loss 0.6988 | Accuracy 0.4833 | TEPS 1090677.00
Epoch 00075 | Time(s) 0.0096 | Loss 0.6994 | Accuracy 0.4833 | TEPS 1096742.78
Epoch 00076 | Time(s) 0.0096 | Loss 0.6982 | Accuracy 0.4833 | TEPS 1104532.20
Epoch 00077 | Time(s) 0.0095 | Loss 0.7010 | Accuracy 0.4867 | TEPS 1111003.21
Epoch 00078 | Time(s) 0.0094 | Loss 0.6982 | Accuracy 0.4833 | TEPS 1117684.05
Epoch 00079 | Time(s) 0.0094 | Loss 0.7022 | Accuracy 0.4833 | TEPS 1125716.75
Epoch 00080 | Time(s) 0.0093 | Loss 0.7010 | Accuracy 0.4867 | TEPS 1132980.49
Epoch 00081 | Time(s) 0.0093 | Loss 0.7199 | Accuracy 0.4833 | TEPS 1138581.82
Epoch 00082 | Time(s) 0.0092 | Loss 0.7025 | Accuracy 0.4800 | TEPS 1144306.68
Epoch 00083 | Time(s) 0.0092 | Loss 0.7514 | Accuracy 0.4567 | TEPS 1150634.14
Epoch 00084 | Time(s) 0.0091 | Loss 0.7068 | Accuracy 0.4767 | TEPS 1158067.50
Epoch 00085 | Time(s) 0.0091 | Loss 0.7117 | Accuracy 0.4767 | TEPS 1163952.10
Epoch 00086 | Time(s) 0.0090 | Loss 0.7314 | Accuracy 0.4700 | TEPS 1169796.76
Epoch 00087 | Time(s) 0.0090 | Loss 0.7189 | Accuracy 0.4667 | TEPS 1176649.11
Epoch 00088 | Time(s) 0.0089 | Loss 0.7071 | Accuracy 0.4800 | TEPS 1182802.09
Epoch 00089 | Time(s) 0.0089 | Loss 0.7029 | Accuracy 0.4800 | TEPS 1187737.32
Epoch 00090 | Time(s) 0.0089 | Loss 0.7044 | Accuracy 0.4833 | TEPS 1192385.26
Epoch 00091 | Time(s) 0.0088 | Loss 0.7053 | Accuracy 0.4733 | TEPS 1197513.64
Epoch 00092 | Time(s) 0.0088 | Loss 0.7024 | Accuracy 0.4733 | TEPS 1204499.93
Epoch 00093 | Time(s) 0.0087 | Loss 0.6977 | Accuracy 0.4800 | TEPS 1210229.46
Epoch 00094 | Time(s) 0.0087 | Loss 0.6950 | Accuracy 0.4733 | TEPS 1215021.05
Epoch 00095 | Time(s) 0.0087 | Loss 0.6946 | Accuracy 0.4800 | TEPS 1219811.69
Epoch 00096 | Time(s) 0.0086 | Loss 0.6936 | Accuracy 0.4833 | TEPS 1225802.76
Epoch 00097 | Time(s) 0.0086 | Loss 0.6953 | Accuracy 0.4867 | TEPS 1231733.78
Epoch 00098 | Time(s) 0.0085 | Loss 0.6934 | Accuracy 0.4833 | TEPS 1236817.64
Epoch 00099 | Time(s) 0.0085 | Loss 0.6938 | Accuracy 0.4833 | TEPS 1241463.93
Epoch 00100 | Time(s) 0.0085 | Loss 0.6915 | Accuracy 0.4833 | TEPS 1247359.45
Epoch 00101 | Time(s) 0.0084 | Loss 0.6908 | Accuracy 0.4833 | TEPS 1252913.05
Epoch 00102 | Time(s) 0.0084 | Loss 0.6914 | Accuracy 0.4800 | TEPS 1257252.11
Epoch 00103 | Time(s) 0.0084 | Loss 0.6922 | Accuracy 0.4733 | TEPS 1261581.95
Epoch 00104 | Time(s) 0.0083 | Loss 0.6925 | Accuracy 0.4767 | TEPS 1266734.50
Epoch 00105 | Time(s) 0.0083 | Loss 0.6927 | Accuracy 0.4767 | TEPS 1272757.33
Epoch 00106 | Time(s) 0.0083 | Loss 0.6922 | Accuracy 0.4767 | TEPS 1277207.69
Epoch 00107 | Time(s) 0.0082 | Loss 0.6920 | Accuracy 0.4767 | TEPS 1281911.79
Epoch 00108 | Time(s) 0.0082 | Loss 0.6918 | Accuracy 0.4767 | TEPS 1287474.68
Epoch 00109 | Time(s) 0.0082 | Loss 0.6917 | Accuracy 0.4833 | TEPS 1292312.85
Epoch 00110 | Time(s) 0.0081 | Loss 0.6925 | Accuracy 0.4867 | TEPS 1295865.36
Epoch 00111 | Time(s) 0.0081 | Loss 0.6928 | Accuracy 0.4867 | TEPS 1299774.71
Epoch 00112 | Time(s) 0.0081 | Loss 0.6920 | Accuracy 0.4833 | TEPS 1304053.80
Epoch 00113 | Time(s) 0.0081 | Loss 0.6922 | Accuracy 0.4833 | TEPS 1309556.55
Epoch 00114 | Time(s) 0.0080 | Loss 0.6922 | Accuracy 0.4767 | TEPS 1314332.16
Epoch 00115 | Time(s) 0.0080 | Loss 0.6924 | Accuracy 0.4767 | TEPS 1318035.82
Epoch 00116 | Time(s) 0.0080 | Loss 0.6924 | Accuracy 0.4767 | TEPS 1321612.28
Epoch 00117 | Time(s) 0.0080 | Loss 0.6920 | Accuracy 0.4767 | TEPS 1326760.37
Epoch 00118 | Time(s) 0.0079 | Loss 0.6921 | Accuracy 0.4833 | TEPS 1332249.50
Epoch 00119 | Time(s) 0.0079 | Loss 0.6923 | Accuracy 0.4833 | TEPS 1337244.49
training complete
In [17]:
# now we can evaluate the model on the test set
output_test = last_output[test_mask]
labels_test = labels[test_mask]
acc = get_accuracy(output_test, labels_test)
print("Test Accuracy {:.4f}".format(acc))
Test Accuracy 0.5370

Further experiments

OK, so not too impressive - how can we improve the model?

For one, there are other aggregation methods explored in the original paper. DGL provides a SAGEConv layer that takes our simplified MeanAggSageLayer further:

In [18]:
from dgl.nn.pytorch.conv.sageconv import SAGEConv
In [19]:
# A new graphSAGE net could be built as follows:
class GraphSAGE(nn.Module):
    """
    GraphSAGE pytorch implementation from paper `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.
    """
    def __init__(
            self,
            g,
            n_features,
            n_hidden,
            n_classes,
            n_layers,
            agg,
            activation,
            dropout,
    ):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g

        # input layer
        self.layers.append(
            SAGEConv(n_features, n_hidden, agg, feat_drop=dropout, activation=activation)
        )
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(
                SAGEConv(n_hidden, n_hidden, agg, feat_drop=dropout, activation=activation)
            )
        # output layer
        self.layers.append(
            SAGEConv(n_hidden, n_classes, agg, feat_drop=dropout, activation=None)
        ) # no activation for the final layer - it outputs raw logits

    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(self.g, h)
        return h

The 'agg' argument can now be one of ['mean', 'gcn', 'pool', 'lstm']. Additionally, a dropout fraction can be set, the activation can be changed from ReLU, and the SAGEConv layer also supports an optional normalisation function.
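Roughly, following the paper, these aggregators build the neighbourhood vector $h_{\mathcal{N}(v)}$ as

$$ \text{mean:}\; h_{\mathcal{N}(v)} = \tfrac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u \qquad \text{pool:}\; h_{\mathcal{N}(v)} = \max_{u \in \mathcal{N}(v)} \sigma(W_{\mathrm{pool}} h_u + b) \qquad \text{lstm:}\; h_{\mathcal{N}(v)} = \mathrm{LSTM}\big([h_u : u \in \pi(\mathcal{N}(v))]\big) $$

where $\pi$ is a random permutation of the neighbourhood (an LSTM is not permutation-invariant), while 'gcn' folds the self vertex into the neighbourhood average and uses a single weight matrix, as in Kipf and Welling's GCN.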

We'll start by looking at the mean aggregation function:

In [20]:
# let's try the same params as before but using a dropout value of 0.5
model = GraphSAGE(g, n_features, n_hidden, n_classes, n_layers, 'mean', F.relu, 0.5)
if gpu >= 0:
    model.cuda()
# initialize the optimizer again as the model params have changed
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
In [21]:
model, last_output, all_train_logits = train(model, optimizer, n_epochs)
Epoch 00000 | Time(s) 0.0071 | Loss 2.0558 | Accuracy 0.1433 | TEPS 1485690.85
Epoch 00001 | Time(s) 0.0065 | Loss 1.9424 | Accuracy 0.1767 | TEPS 1614699.96
Epoch 00002 | Time(s) 0.0064 | Loss 1.8707 | Accuracy 0.2133 | TEPS 1659319.65
Epoch 00003 | Time(s) 0.0065 | Loss 1.7868 | Accuracy 0.2133 | TEPS 1632712.20
Epoch 00004 | Time(s) 0.0062 | Loss 1.7963 | Accuracy 0.2833 | TEPS 1711471.11
Epoch 00005 | Time(s) 0.0059 | Loss 1.7359 | Accuracy 0.3267 | TEPS 1782386.43
Epoch 00006 | Time(s) 0.0059 | Loss 1.7680 | Accuracy 0.3333 | TEPS 1796627.97
Epoch 00007 | Time(s) 0.0058 | Loss 1.7066 | Accuracy 0.3167 | TEPS 1809527.77
Epoch 00008 | Time(s) 0.0057 | Loss 1.7269 | Accuracy 0.3900 | TEPS 1851153.77
Epoch 00009 | Time(s) 0.0056 | Loss 1.6374 | Accuracy 0.3833 | TEPS 1885481.84
Epoch 00010 | Time(s) 0.0056 | Loss 1.5783 | Accuracy 0.4200 | TEPS 1883863.47
Epoch 00011 | Time(s) 0.0056 | Loss 1.5715 | Accuracy 0.4133 | TEPS 1884520.13
Epoch 00012 | Time(s) 0.0055 | Loss 1.5080 | Accuracy 0.4400 | TEPS 1911221.92
Epoch 00013 | Time(s) 0.0054 | Loss 1.5281 | Accuracy 0.3933 | TEPS 1945088.33
Epoch 00014 | Time(s) 0.0054 | Loss 1.4862 | Accuracy 0.3767 | TEPS 1940981.11
Epoch 00015 | Time(s) 0.0054 | Loss 1.4082 | Accuracy 0.3900 | TEPS 1938939.30
Epoch 00016 | Time(s) 0.0054 | Loss 1.3887 | Accuracy 0.3933 | TEPS 1965422.52
Epoch 00017 | Time(s) 0.0053 | Loss 1.3930 | Accuracy 0.3900 | TEPS 1980633.13
Epoch 00018 | Time(s) 0.0053 | Loss 1.3569 | Accuracy 0.3967 | TEPS 1974997.27
Epoch 00019 | Time(s) 0.0054 | Loss 1.3233 | Accuracy 0.4367 | TEPS 1963879.45
Epoch 00020 | Time(s) 0.0054 | Loss 1.2675 | Accuracy 0.4833 | TEPS 1967131.56
Epoch 00021 | Time(s) 0.0053 | Loss 1.2011 | Accuracy 0.4500 | TEPS 1986206.60
Epoch 00022 | Time(s) 0.0053 | Loss 1.1671 | Accuracy 0.5367 | TEPS 1981558.12
Epoch 00023 | Time(s) 0.0054 | Loss 1.2142 | Accuracy 0.4867 | TEPS 1972452.81
Epoch 00024 | Time(s) 0.0053 | Loss 1.1579 | Accuracy 0.4800 | TEPS 1989036.31
Epoch 00025 | Time(s) 0.0053 | Loss 1.1982 | Accuracy 0.5000 | TEPS 1998215.38
Epoch 00026 | Time(s) 0.0053 | Loss 1.1417 | Accuracy 0.5000 | TEPS 1992697.12
Epoch 00027 | Time(s) 0.0053 | Loss 1.1224 | Accuracy 0.5500 | TEPS 1987756.40
Epoch 00028 | Time(s) 0.0053 | Loss 1.0773 | Accuracy 0.5033 | TEPS 1990825.79
Epoch 00029 | Time(s) 0.0053 | Loss 1.0476 | Accuracy 0.5467 | TEPS 1993193.50
Epoch 00030 | Time(s) 0.0053 | Loss 0.9734 | Accuracy 0.5633 | TEPS 1991487.58
Epoch 00031 | Time(s) 0.0053 | Loss 0.9620 | Accuracy 0.5367 | TEPS 1992633.59
Epoch 00032 | Time(s) 0.0053 | Loss 0.9791 | Accuracy 0.5233 | TEPS 2005071.30
Epoch 00033 | Time(s) 0.0052 | Loss 0.9412 | Accuracy 0.5000 | TEPS 2011486.82
Epoch 00034 | Time(s) 0.0053 | Loss 0.9722 | Accuracy 0.5367 | TEPS 2007271.38
Epoch 00035 | Time(s) 0.0053 | Loss 0.9740 | Accuracy 0.5333 | TEPS 2001153.34
Epoch 00036 | Time(s) 0.0053 | Loss 0.9601 | Accuracy 0.5200 | TEPS 2000620.03
Epoch 00037 | Time(s) 0.0053 | Loss 0.8868 | Accuracy 0.5933 | TEPS 2008524.71
Epoch 00038 | Time(s) 0.0053 | Loss 0.8787 | Accuracy 0.5233 | TEPS 2006092.24
Epoch 00039 | Time(s) 0.0053 | Loss 0.9439 | Accuracy 0.5567 | TEPS 2006433.86
Epoch 00040 | Time(s) 0.0052 | Loss 0.7985 | Accuracy 0.6033 | TEPS 2014810.75
Epoch 00041 | Time(s) 0.0052 | Loss 0.7447 | Accuracy 0.5767 | TEPS 2020241.02
Epoch 00042 | Time(s) 0.0052 | Loss 0.8268 | Accuracy 0.5733 | TEPS 2016705.14
Epoch 00043 | Time(s) 0.0052 | Loss 0.6557 | Accuracy 0.5367 | TEPS 2011094.72
Epoch 00044 | Time(s) 0.0052 | Loss 0.7830 | Accuracy 0.5867 | TEPS 2013955.81
Epoch 00045 | Time(s) 0.0052 | Loss 0.7333 | Accuracy 0.5700 | TEPS 2019263.46
Epoch 00046 | Time(s) 0.0052 | Loss 0.7626 | Accuracy 0.6400 | TEPS 2019064.30
Epoch 00047 | Time(s) 0.0052 | Loss 0.7696 | Accuracy 0.5733 | TEPS 2015813.38
Epoch 00048 | Time(s) 0.0052 | Loss 0.7309 | Accuracy 0.5967 | TEPS 2021871.83
Epoch 00049 | Time(s) 0.0052 | Loss 0.6787 | Accuracy 0.5900 | TEPS 2026528.74
Epoch 00050 | Time(s) 0.0052 | Loss 0.7000 | Accuracy 0.6500 | TEPS 2023315.98
Epoch 00051 | Time(s) 0.0052 | Loss 0.6309 | Accuracy 0.6233 | TEPS 2020381.75
Epoch 00052 | Time(s) 0.0052 | Loss 0.6739 | Accuracy 0.6267 | TEPS 2022490.89
Epoch 00053 | Time(s) 0.0052 | Loss 0.6990 | Accuracy 0.5833 | TEPS 2025536.32
Epoch 00054 | Time(s) 0.0052 | Loss 0.6132 | Accuracy 0.5933 | TEPS 2022572.72
Epoch 00055 | Time(s) 0.0052 | Loss 0.6039 | Accuracy 0.6433 | TEPS 2027576.84
Epoch 00056 | Time(s) 0.0052 | Loss 0.5973 | Accuracy 0.6100 | TEPS 2033722.80
Epoch 00057 | Time(s) 0.0052 | Loss 0.5718 | Accuracy 0.6533 | TEPS 2031242.97
Epoch 00058 | Time(s) 0.0052 | Loss 0.5204 | Accuracy 0.6467 | TEPS 2031224.07
Epoch 00059 | Time(s) 0.0052 | Loss 0.5022 | Accuracy 0.6333 | TEPS 2037757.62
Epoch 00060 | Time(s) 0.0052 | Loss 0.4410 | Accuracy 0.6433 | TEPS 2039056.64
Epoch 00061 | Time(s) 0.0052 | Loss 0.5277 | Accuracy 0.6467 | TEPS 2036123.24
Epoch 00062 | Time(s) 0.0052 | Loss 0.4747 | Accuracy 0.6667 | TEPS 2032122.24
Epoch 00063 | Time(s) 0.0052 | Loss 0.5103 | Accuracy 0.7000 | TEPS 2033815.04
Epoch 00064 | Time(s) 0.0052 | Loss 0.5389 | Accuracy 0.6100 | TEPS 2036488.31
Epoch 00065 | Time(s) 0.0052 | Loss 0.4487 | Accuracy 0.6367 | TEPS 2032943.36
Epoch 00066 | Time(s) 0.0052 | Loss 0.4255 | Accuracy 0.6500 | TEPS 2034908.02
Epoch 00067 | Time(s) 0.0052 | Loss 0.4118 | Accuracy 0.6300 | TEPS 2038019.43
Epoch 00068 | Time(s) 0.0052 | Loss 0.4408 | Accuracy 0.6500 | TEPS 2033657.10
Epoch 00069 | Time(s) 0.0052 | Loss 0.4112 | Accuracy 0.6333 | TEPS 2033645.31
Epoch 00070 | Time(s) 0.0052 | Loss 0.4387 | Accuracy 0.6800 | TEPS 2037884.24
Epoch 00071 | Time(s) 0.0052 | Loss 0.4357 | Accuracy 0.6800 | TEPS 2038460.23
Epoch 00072 | Time(s) 0.0052 | Loss 0.4357 | Accuracy 0.6600 | TEPS 2034374.75
Epoch 00073 | Time(s) 0.0052 | Loss 0.3508 | Accuracy 0.6900 | TEPS 2031178.18
Epoch 00074 | Time(s) 0.0052 | Loss 0.3400 | Accuracy 0.6500 | TEPS 2035663.00
Epoch 00075 | Time(s) 0.0052 | Loss 0.3004 | Accuracy 0.6967 | TEPS 2040403.89
Epoch 00076 | Time(s) 0.0052 | Loss 0.3145 | Accuracy 0.6733 | TEPS 2038124.25
Epoch 00077 | Time(s) 0.0052 | Loss 0.3343 | Accuracy 0.7067 | TEPS 2038015.33
Epoch 00078 | Time(s) 0.0052 | Loss 0.3228 | Accuracy 0.7000 | TEPS 2041832.64
Epoch 00079 | Time(s) 0.0052 | Loss 0.3436 | Accuracy 0.6533 | TEPS 2042826.53
Epoch 00080 | Time(s) 0.0052 | Loss 0.3581 | Accuracy 0.6800 | TEPS 2038915.40
Epoch 00081 | Time(s) 0.0052 | Loss 0.2692 | Accuracy 0.6800 | TEPS 2036916.97
Epoch 00082 | Time(s) 0.0052 | Loss 0.3473 | Accuracy 0.6833 | TEPS 2041090.89
Epoch 00083 | Time(s) 0.0052 | Loss 0.3240 | Accuracy 0.6833 | TEPS 2045492.42
Epoch 00084 | Time(s) 0.0052 | Loss 0.3359 | Accuracy 0.6967 | TEPS 2044508.07
Epoch 00085 | Time(s) 0.0052 | Loss 0.2835 | Accuracy 0.7100 | TEPS 2042885.30
Epoch 00086 | Time(s) 0.0052 | Loss 0.3292 | Accuracy 0.7133 | TEPS 2045922.19
Epoch 00087 | Time(s) 0.0052 | Loss 0.3080 | Accuracy 0.7000 | TEPS 2048079.17
Epoch 00088 | Time(s) 0.0052 | Loss 0.2845 | Accuracy 0.6733 | TEPS 2045930.99
Epoch 00089 | Time(s) 0.0052 | Loss 0.2372 | Accuracy 0.7333 | TEPS 2043894.65
Epoch 00090 | Time(s) 0.0052 | Loss 0.3151 | Accuracy 0.6933 | TEPS 2044842.89
Epoch 00091 | Time(s) 0.0052 | Loss 0.3093 | Accuracy 0.7000 | TEPS 2047652.32
Epoch 00092 | Time(s) 0.0052 | Loss 0.3189 | Accuracy 0.6633 | TEPS 2047459.23
Epoch 00093 | Time(s) 0.0052 | Loss 0.3020 | Accuracy 0.6900 | TEPS 2047375.03
Epoch 00094 | Time(s) 0.0051 | Loss 0.2995 | Accuracy 0.7367 | TEPS 2051469.46
Epoch 00095 | Time(s) 0.0051 | Loss 0.2763 | Accuracy 0.7333 | TEPS 2053465.68
Epoch 00096 | Time(s) 0.0051 | Loss 0.3061 | Accuracy 0.6900 | TEPS 2051599.07
Epoch 00097 | Time(s) 0.0052 | Loss 0.2243 | Accuracy 0.7067 | TEPS 2048679.23
Epoch 00098 | Time(s) 0.0052 | Loss 0.2637 | Accuracy 0.6967 | TEPS 2047927.39
Epoch 00099 | Time(s) 0.0051 | Loss 0.2535 | Accuracy 0.6900 | TEPS 2050349.03
Epoch 00100 | Time(s) 0.0052 | Loss 0.2586 | Accuracy 0.7400 | TEPS 2048634.71
Epoch 00101 | Time(s) 0.0051 | Loss 0.2387 | Accuracy 0.7200 | TEPS 2049901.23
Epoch 00102 | Time(s) 0.0051 | Loss 0.1832 | Accuracy 0.7100 | TEPS 2053411.15
Epoch 00103 | Time(s) 0.0051 | Loss 0.2447 | Accuracy 0.7000 | TEPS 2051598.46
Epoch 00104 | Time(s) 0.0051 | Loss 0.2882 | Accuracy 0.7233 | TEPS 2051313.10
Epoch 00105 | Time(s) 0.0051 | Loss 0.1861 | Accuracy 0.7133 | TEPS 2054903.77
Epoch 00106 | Time(s) 0.0051 | Loss 0.1810 | Accuracy 0.7233 | TEPS 2056789.83
Epoch 00107 | Time(s) 0.0051 | Loss 0.2229 | Accuracy 0.6700 | TEPS 2054705.20
Epoch 00108 | Time(s) 0.0051 | Loss 0.2354 | Accuracy 0.7333 | TEPS 2052640.23
Epoch 00109 | Time(s) 0.0051 | Loss 0.2513 | Accuracy 0.7133 | TEPS 2055016.69
Epoch 00110 | Time(s) 0.0051 | Loss 0.2004 | Accuracy 0.7167 | TEPS 2056941.51
Epoch 00111 | Time(s) 0.0051 | Loss 0.2915 | Accuracy 0.6867 | TEPS 2053153.08
Epoch 00112 | Time(s) 0.0051 | Loss 0.2392 | Accuracy 0.7200 | TEPS 2053672.60
Epoch 00113 | Time(s) 0.0051 | Loss 0.1653 | Accuracy 0.7333 | TEPS 2053785.39
Epoch 00114 | Time(s) 0.0052 | Loss 0.2616 | Accuracy 0.6900 | TEPS 2048771.43
Epoch 00115 | Time(s) 0.0052 | Loss 0.2607 | Accuracy 0.7233 | TEPS 2041532.76
Epoch 00116 | Time(s) 0.0052 | Loss 0.1716 | Accuracy 0.7133 | TEPS 2043677.74
Epoch 00117 | Time(s) 0.0052 | Loss 0.2119 | Accuracy 0.7033 | TEPS 2046710.02
Epoch 00118 | Time(s) 0.0052 | Loss 0.1964 | Accuracy 0.7400 | TEPS 2046972.68
Epoch 00119 | Time(s) 0.0052 | Loss 0.2008 | Accuracy 0.7267 | TEPS 2046362.87
training complete
In [22]:
acc = get_accuracy(last_output[test_mask], labels_test)
print("Test Accuracy {:.4f}".format(acc))
Test Accuracy 0.6740

Slightly better! Let's change the aggregation function. In the original GraphSAGE paper they found that the LSTM and pool methods generally outperformed the mean and GCN aggregation across a range of datasets. Let's try the pool method (which refers to a max pool aggregator over the neighbourhood) and bump the number of hidden channels up.

In [23]:
model = GraphSAGE(g, n_features, 128, n_classes, 2, 'pool', F.relu, 0.3)
if gpu >= 0:
    model.cuda()
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.003, weight_decay=weight_decay
)
In [24]:
model, last_output, all_train_logits = train(model, optimizer, n_epochs)
Epoch 00000 | Time(s) 0.0083 | Loss 2.0174 | Accuracy 0.1200 | TEPS 1277595.53
Epoch 00001 | Time(s) 0.0077 | Loss 1.8457 | Accuracy 0.3467 | TEPS 1363841.64
Epoch 00002 | Time(s) 0.0075 | Loss 1.8008 | Accuracy 0.3100 | TEPS 1411052.77
Epoch 00003 | Time(s) 0.0075 | Loss 1.8041 | Accuracy 0.3067 | TEPS 1410079.08
Epoch 00004 | Time(s) 0.0074 | Loss 1.7084 | Accuracy 0.3400 | TEPS 1419664.38
Epoch 00005 | Time(s) 0.0074 | Loss 1.5917 | Accuracy 0.3233 | TEPS 1428089.96
Epoch 00006 | Time(s) 0.0074 | Loss 1.5353 | Accuracy 0.3567 | TEPS 1433506.37
Epoch 00007 | Time(s) 0.0073 | Loss 1.3963 | Accuracy 0.4200 | TEPS 1443754.34
Epoch 00008 | Time(s) 0.0073 | Loss 1.4551 | Accuracy 0.3500 | TEPS 1446125.91
Epoch 00009 | Time(s) 0.0073 | Loss 1.2598 | Accuracy 0.4767 | TEPS 1449954.09
Epoch 00010 | Time(s) 0.0072 | Loss 1.1592 | Accuracy 0.5000 | TEPS 1457449.82
Epoch 00011 | Time(s) 0.0072 | Loss 1.0904 | Accuracy 0.5133 | TEPS 1459168.82
Epoch 00012 | Time(s) 0.0072 | Loss 0.9747 | Accuracy 0.5900 | TEPS 1459211.98
Epoch 00013 | Time(s) 0.0072 | Loss 0.9672 | Accuracy 0.5600 | TEPS 1464904.79
Epoch 00014 | Time(s) 0.0072 | Loss 0.8351 | Accuracy 0.6167 | TEPS 1466489.34
Epoch 00015 | Time(s) 0.0072 | Loss 0.7824 | Accuracy 0.6133 | TEPS 1468128.09
Epoch 00016 | Time(s) 0.0072 | Loss 0.6972 | Accuracy 0.6333 | TEPS 1472284.96
Epoch 00017 | Time(s) 0.0072 | Loss 0.6629 | Accuracy 0.6200 | TEPS 1474576.92
Epoch 00018 | Time(s) 0.0071 | Loss 0.5557 | Accuracy 0.6467 | TEPS 1478691.98
Epoch 00019 | Time(s) 0.0071 | Loss 0.5078 | Accuracy 0.6533 | TEPS 1483125.33
Epoch 00020 | Time(s) 0.0071 | Loss 0.5963 | Accuracy 0.6667 | TEPS 1490129.18
Epoch 00021 | Time(s) 0.0071 | Loss 0.4609 | Accuracy 0.7033 | TEPS 1496121.82
Epoch 00022 | Time(s) 0.0070 | Loss 0.4373 | Accuracy 0.6900 | TEPS 1502140.64
Epoch 00023 | Time(s) 0.0070 | Loss 0.3905 | Accuracy 0.7067 | TEPS 1507488.85
Epoch 00024 | Time(s) 0.0070 | Loss 0.2703 | Accuracy 0.6967 | TEPS 1510438.92
Epoch 00025 | Time(s) 0.0070 | Loss 0.4080 | Accuracy 0.6500 | TEPS 1513104.70
Epoch 00026 | Time(s) 0.0070 | Loss 0.3532 | Accuracy 0.6267 | TEPS 1512957.39
Epoch 00027 | Time(s) 0.0070 | Loss 0.2651 | Accuracy 0.7067 | TEPS 1514939.23
Epoch 00028 | Time(s) 0.0069 | Loss 0.2107 | Accuracy 0.7000 | TEPS 1519434.81
Epoch 00029 | Time(s) 0.0069 | Loss 0.2782 | Accuracy 0.6900 | TEPS 1523602.40
Epoch 00030 | Time(s) 0.0069 | Loss 0.2956 | Accuracy 0.7233 | TEPS 1523946.60
Epoch 00031 | Time(s) 0.0069 | Loss 0.1970 | Accuracy 0.7567 | TEPS 1522654.21
Epoch 00032 | Time(s) 0.0069 | Loss 0.1694 | Accuracy 0.7300 | TEPS 1524356.50
Epoch 00033 | Time(s) 0.0069 | Loss 0.1773 | Accuracy 0.7167 | TEPS 1527800.45
Epoch 00034 | Time(s) 0.0069 | Loss 0.1943 | Accuracy 0.7300 | TEPS 1529716.76
Epoch 00035 | Time(s) 0.0069 | Loss 0.2696 | Accuracy 0.6700 | TEPS 1531241.18
Epoch 00036 | Time(s) 0.0069 | Loss 0.1375 | Accuracy 0.6967 | TEPS 1532604.26
Epoch 00037 | Time(s) 0.0069 | Loss 0.2249 | Accuracy 0.7033 | TEPS 1534064.27
Epoch 00038 | Time(s) 0.0069 | Loss 0.2249 | Accuracy 0.7167 | TEPS 1537038.81
Epoch 00039 | Time(s) 0.0069 | Loss 0.1069 | Accuracy 0.7333 | TEPS 1538318.42
Epoch 00040 | Time(s) 0.0069 | Loss 0.2048 | Accuracy 0.7133 | TEPS 1540869.22
Epoch 00041 | Time(s) 0.0068 | Loss 0.1598 | Accuracy 0.7167 | TEPS 1542099.54
Epoch 00042 | Time(s) 0.0068 | Loss 0.0943 | Accuracy 0.7400 | TEPS 1542699.22
Epoch 00043 | Time(s) 0.0068 | Loss 0.1234 | Accuracy 0.7467 | TEPS 1545011.31
Epoch 00044 | Time(s) 0.0068 | Loss 0.1193 | Accuracy 0.7067 | TEPS 1546090.10
Epoch 00045 | Time(s) 0.0068 | Loss 0.1316 | Accuracy 0.7167 | TEPS 1546322.29
Epoch 00046 | Time(s) 0.0068 | Loss 0.0913 | Accuracy 0.7100 | TEPS 1547159.84
Epoch 00047 | Time(s) 0.0068 | Loss 0.0649 | Accuracy 0.7467 | TEPS 1547691.65
Epoch 00048 | Time(s) 0.0068 | Loss 0.1130 | Accuracy 0.7000 | TEPS 1549759.29
Epoch 00049 | Time(s) 0.0068 | Loss 0.0958 | Accuracy 0.7433 | TEPS 1551775.52
Epoch 00050 | Time(s) 0.0068 | Loss 0.0862 | Accuracy 0.7633 | TEPS 1552700.53
Epoch 00051 | Time(s) 0.0068 | Loss 0.0618 | Accuracy 0.7267 | TEPS 1554136.34
Epoch 00052 | Time(s) 0.0068 | Loss 0.0613 | Accuracy 0.7333 | TEPS 1549626.24
Epoch 00053 | Time(s) 0.0068 | Loss 0.0389 | Accuracy 0.7733 | TEPS 1550189.36
Epoch 00054 | Time(s) 0.0068 | Loss 0.0435 | Accuracy 0.7367 | TEPS 1551034.63
Epoch 00055 | Time(s) 0.0068 | Loss 0.0588 | Accuracy 0.7267 | TEPS 1552035.15
Epoch 00056 | Time(s) 0.0068 | Loss 0.0586 | Accuracy 0.7600 | TEPS 1553865.25
Epoch 00057 | Time(s) 0.0068 | Loss 0.0665 | Accuracy 0.7600 | TEPS 1555649.53
Epoch 00058 | Time(s) 0.0068 | Loss 0.0516 | Accuracy 0.7333 | TEPS 1556432.60
Epoch 00059 | Time(s) 0.0068 | Loss 0.0443 | Accuracy 0.7733 | TEPS 1558149.33
Epoch 00060 | Time(s) 0.0068 | Loss 0.0415 | Accuracy 0.7500 | TEPS 1558042.52
Epoch 00061 | Time(s) 0.0068 | Loss 0.0333 | Accuracy 0.7367 | TEPS 1559531.46
Epoch 00062 | Time(s) 0.0068 | Loss 0.0405 | Accuracy 0.7333 | TEPS 1557133.81
Epoch 00063 | Time(s) 0.0068 | Loss 0.0613 | Accuracy 0.7533 | TEPS 1553444.05
Epoch 00064 | Time(s) 0.0068 | Loss 0.0760 | Accuracy 0.7200 | TEPS 1549478.09
Epoch 00065 | Time(s) 0.0068 | Loss 0.0517 | Accuracy 0.7333 | TEPS 1544782.21
Epoch 00066 | Time(s) 0.0068 | Loss 0.0318 | Accuracy 0.7700 | TEPS 1543405.44
Epoch 00067 | Time(s) 0.0068 | Loss 0.0835 | Accuracy 0.7400 | TEPS 1545022.48
Epoch 00068 | Time(s) 0.0068 | Loss 0.0673 | Accuracy 0.7467 | TEPS 1545586.52
Epoch 00069 | Time(s) 0.0068 | Loss 0.0478 | Accuracy 0.7533 | TEPS 1546813.13
Epoch 00070 | Time(s) 0.0068 | Loss 0.0295 | Accuracy 0.7200 | TEPS 1546736.57
Epoch 00071 | Time(s) 0.0068 | Loss 0.0252 | Accuracy 0.7600 | TEPS 1547457.25
Epoch 00072 | Time(s) 0.0068 | Loss 0.0511 | Accuracy 0.7500 | TEPS 1548935.63
Epoch 00073 | Time(s) 0.0068 | Loss 0.0610 | Accuracy 0.7133 | TEPS 1550370.90
Epoch 00074 | Time(s) 0.0068 | Loss 0.0364 | Accuracy 0.7800 | TEPS 1550044.31
Epoch 00075 | Time(s) 0.0068 | Loss 0.0422 | Accuracy 0.7767 | TEPS 1551042.98
Epoch 00076 | Time(s) 0.0068 | Loss 0.0366 | Accuracy 0.7700 | TEPS 1552378.09
Epoch 00077 | Time(s) 0.0068 | Loss 0.0445 | Accuracy 0.7200 | TEPS 1553729.40
Epoch 00078 | Time(s) 0.0068 | Loss 0.0869 | Accuracy 0.7400 | TEPS 1554325.95
Epoch 00079 | Time(s) 0.0068 | Loss 0.0439 | Accuracy 0.7200 | TEPS 1555273.98
Epoch 00080 | Time(s) 0.0068 | Loss 0.0299 | Accuracy 0.7467 | TEPS 1556530.67
Epoch 00081 | Time(s) 0.0068 | Loss 0.0507 | Accuracy 0.7200 | TEPS 1556992.41
Epoch 00082 | Time(s) 0.0068 | Loss 0.0503 | Accuracy 0.7367 | TEPS 1558052.77
Epoch 00083 | Time(s) 0.0068 | Loss 0.0580 | Accuracy 0.7267 | TEPS 1558402.01
Epoch 00084 | Time(s) 0.0068 | Loss 0.0259 | Accuracy 0.7333 | TEPS 1558393.98
Epoch 00085 | Time(s) 0.0068 | Loss 0.0509 | Accuracy 0.7333 | TEPS 1559469.90
Epoch 00086 | Time(s) 0.0068 | Loss 0.0318 | Accuracy 0.7567 | TEPS 1560302.54
Epoch 00087 | Time(s) 0.0068 | Loss 0.0304 | Accuracy 0.7500 | TEPS 1561123.37
Epoch 00088 | Time(s) 0.0068 | Loss 0.0642 | Accuracy 0.7167 | TEPS 1562201.52
Epoch 00089 | Time(s) 0.0068 | Loss 0.0318 | Accuracy 0.7800 | TEPS 1562586.52
Epoch 00090 | Time(s) 0.0068 | Loss 0.0551 | Accuracy 0.7333 | TEPS 1561342.43
Epoch 00091 | Time(s) 0.0068 | Loss 0.0542 | Accuracy 0.7700 | TEPS 1562048.99
Epoch 00092 | Time(s) 0.0068 | Loss 0.0466 | Accuracy 0.7400 | TEPS 1562458.11
Epoch 00093 | Time(s) 0.0068 | Loss 0.0428 | Accuracy 0.7433 | TEPS 1563365.97
Epoch 00094 | Time(s) 0.0068 | Loss 0.0498 | Accuracy 0.7200 | TEPS 1563762.57
Epoch 00095 | Time(s) 0.0067 | Loss 0.0599 | Accuracy 0.7233 | TEPS 1563947.37
Epoch 00096 | Time(s) 0.0067 | Loss 0.0257 | Accuracy 0.7533 | TEPS 1564707.38
Epoch 00097 | Time(s) 0.0067 | Loss 0.0316 | Accuracy 0.7600 | TEPS 1565639.01
Epoch 00098 | Time(s) 0.0067 | Loss 0.0279 | Accuracy 0.7467 | TEPS 1565969.72
Epoch 00099 | Time(s) 0.0067 | Loss 0.0638 | Accuracy 0.7467 | TEPS 1564806.49
Epoch 00100 | Time(s) 0.0067 | Loss 0.0312 | Accuracy 0.7467 | TEPS 1565428.93
Epoch 00101 | Time(s) 0.0067 | Loss 0.0293 | Accuracy 0.7333 | TEPS 1565812.69
Epoch 00102 | Time(s) 0.0067 | Loss 0.0301 | Accuracy 0.7733 | TEPS 1566651.89
Epoch 00103 | Time(s) 0.0067 | Loss 0.0432 | Accuracy 0.7300 | TEPS 1567098.14
Epoch 00104 | Time(s) 0.0067 | Loss 0.0353 | Accuracy 0.7600 | TEPS 1567668.81
Epoch 00105 | Time(s) 0.0067 | Loss 0.0233 | Accuracy 0.7600 | TEPS 1567902.72
Epoch 00106 | Time(s) 0.0067 | Loss 0.0879 | Accuracy 0.7633 | TEPS 1568157.76
Epoch 00107 | Time(s) 0.0067 | Loss 0.0411 | Accuracy 0.7600 | TEPS 1568410.73
Epoch 00108 | Time(s) 0.0067 | Loss 0.0168 | Accuracy 0.7567 | TEPS 1568954.41
Epoch 00109 | Time(s) 0.0067 | Loss 0.0782 | Accuracy 0.6733 | TEPS 1569735.44
Epoch 00110 | Time(s) 0.0067 | Loss 0.0278 | Accuracy 0.7200 | TEPS 1569820.40
Epoch 00111 | Time(s) 0.0067 | Loss 0.0180 | Accuracy 0.7767 | TEPS 1570480.59
Epoch 00112 | Time(s) 0.0067 | Loss 0.0247 | Accuracy 0.7467 | TEPS 1571129.15
Epoch 00113 | Time(s) 0.0067 | Loss 0.0303 | Accuracy 0.7567 | TEPS 1571824.61
Epoch 00114 | Time(s) 0.0067 | Loss 0.0298 | Accuracy 0.7400 | TEPS 1572496.93
Epoch 00115 | Time(s) 0.0067 | Loss 0.0273 | Accuracy 0.7633 | TEPS 1572730.42
Epoch 00116 | Time(s) 0.0067 | Loss 0.0304 | Accuracy 0.7433 | TEPS 1573065.08
Epoch 00117 | Time(s) 0.0067 | Loss 0.0233 | Accuracy 0.7500 | TEPS 1573275.28
Epoch 00118 | Time(s) 0.0067 | Loss 0.0305 | Accuracy 0.7567 | TEPS 1573445.35
Epoch 00119 | Time(s) 0.0067 | Loss 0.0266 | Accuracy 0.7433 | TEPS 1573915.17
training complete
In [25]:
acc = get_accuracy(last_output[test_mask], labels_test)
print("Test Accuracy {:.4f}".format(acc))
Test Accuracy 0.7440

And finally, the LSTM aggregation function:

In [33]:
model = GraphSAGE(g, n_features, 128, n_classes, 2, 'lstm', F.relu, 0.1)
if gpu >= 0:
    model.cuda()
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.003, weight_decay=weight_decay
)
In [34]:
model, last_output, all_train_logits = train(model, optimizer, n_epochs)
Epoch 00000 | Time(s) 0.3460 | Loss 2.0057 | Accuracy 0.0967 | TEPS 30512.46
Epoch 00001 | Time(s) 0.3336 | Loss 1.8551 | Accuracy 0.3500 | TEPS 31641.23
Epoch 00002 | Time(s) 0.3284 | Loss 1.8039 | Accuracy 0.3033 | TEPS 32139.34
Epoch 00003 | Time(s) 0.3260 | Loss 1.7618 | Accuracy 0.3533 | TEPS 32382.68
Epoch 00004 | Time(s) 0.3243 | Loss 1.7365 | Accuracy 0.3767 | TEPS 32545.76
Epoch 00005 | Time(s) 0.3231 | Loss 1.6779 | Accuracy 0.3500 | TEPS 32667.21
Epoch 00006 | Time(s) 0.3223 | Loss 1.6484 | Accuracy 0.3200 | TEPS 32747.99
Epoch 00007 | Time(s) 0.3217 | Loss 1.7306 | Accuracy 0.3500 | TEPS 32813.20
Epoch 00008 | Time(s) 0.3212 | Loss 1.4687 | Accuracy 0.4300 | TEPS 32869.37
Epoch 00009 | Time(s) 0.3208 | Loss 1.4983 | Accuracy 0.4400 | TEPS 32903.87
Epoch 00010 | Time(s) 0.3235 | Loss 1.3894 | Accuracy 0.4600 | TEPS 32632.24
Epoch 00011 | Time(s) 0.3237 | Loss 1.3150 | Accuracy 0.4833 | TEPS 32610.46
Epoch 00012 | Time(s) 0.3237 | Loss 1.2075 | Accuracy 0.4900 | TEPS 32606.76
Epoch 00013 | Time(s) 0.3235 | Loss 0.9814 | Accuracy 0.5400 | TEPS 32632.49
Epoch 00014 | Time(s) 0.3232 | Loss 1.1288 | Accuracy 0.5367 | TEPS 32665.13
Epoch 00015 | Time(s) 0.3229 | Loss 0.9980 | Accuracy 0.5400 | TEPS 32688.15
Epoch 00016 | Time(s) 0.3226 | Loss 1.0572 | Accuracy 0.4833 | TEPS 32718.22
Epoch 00017 | Time(s) 0.3224 | Loss 0.9635 | Accuracy 0.5267 | TEPS 32738.01
Epoch 00018 | Time(s) 0.3223 | Loss 0.7166 | Accuracy 0.5767 | TEPS 32754.20
Epoch 00019 | Time(s) 0.3221 | Loss 0.8618 | Accuracy 0.5100 | TEPS 32769.26
Epoch 00020 | Time(s) 0.3220 | Loss 0.7159 | Accuracy 0.6033 | TEPS 32784.46
Epoch 00021 | Time(s) 0.3219 | Loss 0.7758 | Accuracy 0.5233 | TEPS 32797.16
Epoch 00022 | Time(s) 0.3217 | Loss 0.5731 | Accuracy 0.6167 | TEPS 32809.91
Epoch 00023 | Time(s) 0.3216 | Loss 0.5887 | Accuracy 0.6233 | TEPS 32820.79
Epoch 00024 | Time(s) 0.3215 | Loss 0.5071 | Accuracy 0.6500 | TEPS 32829.34
Epoch 00025 | Time(s) 0.3236 | Loss 0.3827 | Accuracy 0.6733 | TEPS 32622.11
Epoch 00026 | Time(s) 0.3239 | Loss 0.5605 | Accuracy 0.5500 | TEPS 32594.45
Epoch 00027 | Time(s) 0.3243 | Loss 0.3033 | Accuracy 0.6967 | TEPS 32550.84
Epoch 00028 | Time(s) 0.3242 | Loss 0.3426 | Accuracy 0.6700 | TEPS 32564.09
Epoch 00029 | Time(s) 0.3240 | Loss 0.2604 | Accuracy 0.6900 | TEPS 32577.30
Epoch 00030 | Time(s) 0.3239 | Loss 0.2260 | Accuracy 0.7033 | TEPS 32591.41
Epoch 00031 | Time(s) 0.3238 | Loss 0.2039 | Accuracy 0.6800 | TEPS 32602.75
Epoch 00032 | Time(s) 0.3237 | Loss 0.2324 | Accuracy 0.6600 | TEPS 32615.38
Epoch 00033 | Time(s) 0.3235 | Loss 0.1629 | Accuracy 0.7200 | TEPS 32626.08
Epoch 00034 | Time(s) 0.3235 | Loss 0.1364 | Accuracy 0.7800 | TEPS 32634.02
Epoch 00035 | Time(s) 0.3233 | Loss 0.2823 | Accuracy 0.6400 | TEPS 32646.31
Epoch 00036 | Time(s) 0.3233 | Loss 0.5731 | Accuracy 0.5733 | TEPS 32652.10
Epoch 00037 | Time(s) 0.3232 | Loss 0.4843 | Accuracy 0.5800 | TEPS 32659.59
Epoch 00038 | Time(s) 0.3231 | Loss 0.1954 | Accuracy 0.6867 | TEPS 32667.08
Epoch 00039 | Time(s) 0.3231 | Loss 0.4060 | Accuracy 0.6067 | TEPS 32670.03
Epoch 00040 | Time(s) 0.3230 | Loss 0.3230 | Accuracy 0.6167 | TEPS 32679.75
Epoch 00041 | Time(s) 0.3229 | Loss 0.2139 | Accuracy 0.6567 | TEPS 32688.73
Epoch 00042 | Time(s) 0.3224 | Loss 0.2751 | Accuracy 0.6600 | TEPS 32743.92
Epoch 00043 | Time(s) 0.3218 | Loss 0.2614 | Accuracy 0.6633 | TEPS 32800.25
Epoch 00044 | Time(s) 0.3213 | Loss 0.1765 | Accuracy 0.6867 | TEPS 32852.86
Epoch 00045 | Time(s) 0.3211 | Loss 0.1429 | Accuracy 0.6800 | TEPS 32872.09
Epoch 00046 | Time(s) 0.3209 | Loss 0.1490 | Accuracy 0.6900 | TEPS 32890.47
Epoch 00047 | Time(s) 0.3208 | Loss 0.1207 | Accuracy 0.6800 | TEPS 32907.71
Epoch 00048 | Time(s) 0.3198 | Loss 0.1251 | Accuracy 0.6867 | TEPS 33003.62
Epoch 00049 | Time(s) 0.3193 | Loss 0.0912 | Accuracy 0.7000 | TEPS 33058.08
Epoch 00050 | Time(s) 0.3186 | Loss 0.0794 | Accuracy 0.6867 | TEPS 33130.02
Epoch 00051 | Time(s) 0.3182 | Loss 0.0577 | Accuracy 0.7200 | TEPS 33175.04
Epoch 00052 | Time(s) 0.3176 | Loss 0.0601 | Accuracy 0.7033 | TEPS 33240.98
Epoch 00053 | Time(s) 0.3168 | Loss 0.0676 | Accuracy 0.6867 | TEPS 33322.11
Epoch 00054 | Time(s) 0.3160 | Loss 0.0545 | Accuracy 0.6933 | TEPS 33404.18
Epoch 00055 | Time(s) 0.3153 | Loss 0.0410 | Accuracy 0.7133 | TEPS 33480.80
Epoch 00056 | Time(s) 0.3146 | Loss 0.0293 | Accuracy 0.7000 | TEPS 33554.51
Epoch 00057 | Time(s) 0.3140 | Loss 0.0233 | Accuracy 0.7100 | TEPS 33622.87
Epoch 00058 | Time(s) 0.3133 | Loss 0.0223 | Accuracy 0.7200 | TEPS 33695.00
Epoch 00059 | Time(s) 0.3127 | Loss 0.0390 | Accuracy 0.6967 | TEPS 33762.99
Epoch 00060 | Time(s) 0.3120 | Loss 0.0187 | Accuracy 0.7000 | TEPS 33829.00
Epoch 00061 | Time(s) 0.3114 | Loss 0.0337 | Accuracy 0.6733 | TEPS 33896.18
Epoch 00062 | Time(s) 0.3108 | Loss 0.0540 | Accuracy 0.7067 | TEPS 33960.84
Epoch 00063 | Time(s) 0.3103 | Loss 0.0555 | Accuracy 0.6967 | TEPS 34020.24
Epoch 00064 | Time(s) 0.3098 | Loss 0.0290 | Accuracy 0.6900 | TEPS 34078.19
Epoch 00065 | Time(s) 0.3093 | Loss 0.0245 | Accuracy 0.7133 | TEPS 34132.81
Epoch 00066 | Time(s) 0.3088 | Loss 0.0382 | Accuracy 0.7067 | TEPS 34186.73
Epoch 00067 | Time(s) 0.3083 | Loss 0.0224 | Accuracy 0.7167 | TEPS 34240.41
Epoch 00068 | Time(s) 0.3078 | Loss 0.0160 | Accuracy 0.7167 | TEPS 34292.07
Epoch 00069 | Time(s) 0.3074 | Loss 0.0157 | Accuracy 0.6900 | TEPS 34340.59
Epoch 00070 | Time(s) 0.3070 | Loss 0.0260 | Accuracy 0.7100 | TEPS 34389.20
Epoch 00071 | Time(s) 0.3065 | Loss 0.0153 | Accuracy 0.7267 | TEPS 34435.19
Epoch 00072 | Time(s) 0.3062 | Loss 0.0284 | Accuracy 0.7133 | TEPS 34479.32
Epoch 00073 | Time(s) 0.3058 | Loss 0.0221 | Accuracy 0.7167 | TEPS 34521.67
Epoch 00074 | Time(s) 0.3054 | Loss 0.0253 | Accuracy 0.7233 | TEPS 34566.44
Epoch 00075 | Time(s) 0.3050 | Loss 0.0163 | Accuracy 0.7067 | TEPS 34607.01
Epoch 00076 | Time(s) 0.3052 | Loss 0.0178 | Accuracy 0.7233 | TEPS 34584.88
Epoch 00077 | Time(s) 0.3049 | Loss 0.0240 | Accuracy 0.7067 | TEPS 34622.11
Epoch 00078 | Time(s) 0.3045 | Loss 0.0809 | Accuracy 0.7267 | TEPS 34665.02
Epoch 00079 | Time(s) 0.3041 | Loss 0.0566 | Accuracy 0.7333 | TEPS 34706.89
Epoch 00080 | Time(s) 0.3037 | Loss 0.0679 | Accuracy 0.7100 | TEPS 34755.74
Epoch 00081 | Time(s) 0.3033 | Loss 0.0333 | Accuracy 0.7300 | TEPS 34801.41
Epoch 00082 | Time(s) 0.3029 | Loss 0.0321 | Accuracy 0.7267 | TEPS 34846.30
Epoch 00083 | Time(s) 0.3025 | Loss 0.0253 | Accuracy 0.7100 | TEPS 34890.35
Epoch 00084 | Time(s) 0.3022 | Loss 0.0393 | Accuracy 0.7400 | TEPS 34932.74
Epoch 00085 | Time(s) 0.3018 | Loss 0.0207 | Accuracy 0.7167 | TEPS 34976.52
Epoch 00086 | Time(s) 0.3015 | Loss 0.0151 | Accuracy 0.7433 | TEPS 35012.38
Epoch 00087 | Time(s) 0.3012 | Loss 0.0129 | Accuracy 0.7500 | TEPS 35049.73
Epoch 00088 | Time(s) 0.3014 | Loss 0.0158 | Accuracy 0.7500 | TEPS 35020.74
Epoch 00089 | Time(s) 0.3017 | Loss 0.0193 | Accuracy 0.7300 | TEPS 34988.06
Epoch 00090 | Time(s) 0.3019 | Loss 0.0348 | Accuracy 0.7033 | TEPS 34968.26
Epoch 00091 | Time(s) 0.3020 | Loss 0.0126 | Accuracy 0.7500 | TEPS 34949.93
Epoch 00092 | Time(s) 0.3022 | Loss 0.0113 | Accuracy 0.7667 | TEPS 34931.42
Epoch 00093 | Time(s) 0.3023 | Loss 0.0111 | Accuracy 0.7633 | TEPS 34914.06
Epoch 00094 | Time(s) 0.3025 | Loss 0.0146 | Accuracy 0.7333 | TEPS 34896.31
Epoch 00095 | Time(s) 0.3026 | Loss 0.0172 | Accuracy 0.7433 | TEPS 34878.75
Epoch 00096 | Time(s) 0.3028 | Loss 0.0112 | Accuracy 0.7533 | TEPS 34859.78
Epoch 00097 | Time(s) 0.3030 | Loss 0.0097 | Accuracy 0.7433 | TEPS 34842.60
Epoch 00098 | Time(s) 0.3031 | Loss 0.0142 | Accuracy 0.7500 | TEPS 34825.57
Epoch 00099 | Time(s) 0.3033 | Loss 0.0148 | Accuracy 0.7233 | TEPS 34808.39
Epoch 00100 | Time(s) 0.3034 | Loss 0.0126 | Accuracy 0.7600 | TEPS 34792.15
Epoch 00101 | Time(s) 0.3035 | Loss 0.0107 | Accuracy 0.7500 | TEPS 34777.40
Epoch 00102 | Time(s) 0.3037 | Loss 0.0081 | Accuracy 0.7600 | TEPS 34763.09
Epoch 00103 | Time(s) 0.3038 | Loss 0.0129 | Accuracy 0.7500 | TEPS 34748.64
Epoch 00104 | Time(s) 0.3039 | Loss 0.0097 | Accuracy 0.7500 | TEPS 34734.12
Epoch 00105 | Time(s) 0.3040 | Loss 0.0084 | Accuracy 0.7500 | TEPS 34718.30
Epoch 00106 | Time(s) 0.3042 | Loss 0.0133 | Accuracy 0.7367 | TEPS 34701.78
Epoch 00107 | Time(s) 0.3043 | Loss 0.0085 | Accuracy 0.7700 | TEPS 34688.19
Epoch 00108 | Time(s) 0.3044 | Loss 0.0122 | Accuracy 0.7600 | TEPS 34674.32
Epoch 00109 | Time(s) 0.3046 | Loss 0.0130 | Accuracy 0.7700 | TEPS 34660.66
Epoch 00110 | Time(s) 0.3047 | Loss 0.0130 | Accuracy 0.7500 | TEPS 34646.80
Epoch 00111 | Time(s) 0.3048 | Loss 0.0252 | Accuracy 0.7467 | TEPS 34634.26
Epoch 00112 | Time(s) 0.3049 | Loss 0.0196 | Accuracy 0.7433 | TEPS 34621.93
Epoch 00113 | Time(s) 0.3050 | Loss 0.0130 | Accuracy 0.7567 | TEPS 34609.54
Epoch 00114 | Time(s) 0.3051 | Loss 0.0121 | Accuracy 0.7400 | TEPS 34597.31
Epoch 00115 | Time(s) 0.3052 | Loss 0.0148 | Accuracy 0.7367 | TEPS 34585.32
Epoch 00116 | Time(s) 0.3053 | Loss 0.0125 | Accuracy 0.7433 | TEPS 34573.64
Epoch 00117 | Time(s) 0.3054 | Loss 0.0140 | Accuracy 0.7333 | TEPS 34562.07
Epoch 00118 | Time(s) 0.3055 | Loss 0.0166 | Accuracy 0.7467 | TEPS 34550.82
Epoch 00119 | Time(s) 0.3056 | Loss 0.0134 | Accuracy 0.7567 | TEPS 34539.08
training complete
In [35]:
acc = get_accuracy(last_output[test_mask], labels_test)
print("Test Accuracy {:.4f}".format(acc))
Test Accuracy 0.6830

Not bad! See how high you can get the accuracy with some tweaking. Compare against the state-of-the-art here: https://paperswithcode.com/sota/node-classification-on-cora

We can plot an animation of the predictions during training (although we are limited to 2D)

In [26]:
# one colour for each class
colors = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'pink']
# to keep the graph small let's only consider training nodes
train_nodes = train_mask.cpu().numpy()
non_train = np.ones(len(train_nodes))
non_train[train_nodes] = 0
non_train = np.where(non_train)[0]
nx_g = model.g.to_networkx()
nx_g.remove_nodes_from(non_train)
rn_nodes = range(nx_g.number_of_nodes())
In [27]:
def draw_epoch(i):
    current_colors = []
    if gpu >= 0:
        logits = all_train_logits[i].detach().cpu().numpy()
    else:
        logits = all_train_logits[i].detach().numpy()
        
    max_ix = logits.argmax(axis=1)
    
    # two alternative layouts, kept here commented out for reference:
    # (a) x, y position based on the magnitude of the two highest logits
    #min_ix = max_ix - 1
    #pos = {n: [logits[n, max_ix[n]], logits[n, min_ix[n]]] for n in rn_nodes}
    #node_size = 100
    # (b) x=node_index, y=certainty, color=class
    #pos = {n: [n, logits[n, max_ix[n]]] for n in rn_nodes}
    #node_size = 100
    
    # x=node_index, y = class, size = certainty
    pos = {n: [n, max_ix[n]] for n in rn_nodes}
    node_size = logits.max(axis=1) * 100
    
    cols = [colors[max_ix[n]] for n in rn_nodes]
    
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(nx_g, pos, node_color=cols,
            with_labels=True, node_size=node_size, ax=ax,
            edge_color='purple', arrows=False, alpha=0.6)

fig = plt.figure(dpi=100)
fig.clf()
ax = fig.subplots()
draw_epoch(0)  # draw the prediction of the first epoch
plt.close()
In [28]:
ani = animation.FuncAnimation(fig, draw_epoch, frames=len(all_train_logits), interval=100)
In [29]:
HTML(ani.to_jshtml())
Out[29]:

Note how the separation of the nodes into classes improves with more training epochs.
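Finally, recall the headline claim: because the learned weights are shared across vertices rather than tied to one fixed graph, a trained GraphSAGE model can produce predictions for vertices it never saw during training. A minimal sketch of the idea (our addition; the toy graph and zeroed features are stand-ins for genuinely unseen data, our wrapper reads its graph from self.g, and we assume the pre-0.5 DGL API used throughout this notebook):

# a toy 'unseen' graph of 5 vertices with hypothetical feature rows
new_g = dgl.DGLGraph(nx.path_graph(5))
new_features = torch.zeros(5, n_features)
if gpu >= 0:
    new_features = new_features.cuda()

model.g = new_g   # point the trained model at the new graph
model.eval()
with torch.no_grad():
    new_logits = model(new_features)
print(new_logits.shape)  # torch.Size([5, 7]): one row of class logits per unseen vertex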
